Keyword [STN]
Jaderberg M, Simonyan K, Zisserman A. Spatial transformer networks[C]//Advances in neural information processing systems. 2015: 2017-2025.
1. Overview
- 虽然CNN的效果很好,但是仍然缺乏对数据的空间不变能力,从而限制了计算和参数的效率。因此,论文提出Spatial Transformer Network (STN)。
1.1. STN
- 在网络中对数据显式地进行空间操作(平移、旋转、缩放、裁剪、扭曲)。由于该操作可微,因此模型能够end to end训练。
- 根据输入数据,动态生成空间操作参数Θ。
- 网络参数直接通过loss回传进行学习。可直接添加到神经网络模型中,整个训练不需额外的监督信息加入。
- 空间操作后的数据是与后续特定任务高度相关的。另一方面,变换后的低分辨率数据比原始数据的计算效率更高。
- 通过对数据进行操作实现不变性,而不是对特征提取器(卷积核)。
1.2. 适用的任务
- classification
- co-localization
- spatial attention
2. Spatial Transformers
STN包含3部分 (Figure 2)
- localization network.
- grid generator.
- sampler.
2.1. Localization Network
- 输入U(h, w, c)
- 输出空间变换参数Θ
网络可以是任何形式,如FCN、CNN等。仿射变换Θ的参数为6,投影变换参数为8,以及thin plate spline (TPS). 模型对最后一层的weight矩阵初始化为0,bias初始化为[[1, 0, 0], [0, 1, 0]](仿射变换),即全等变换。
2.2. Parameterised Sampling Grid
- 首先根据采样网格大小(超参数)生成标准网格(t; x,y∈(-1, 1); (h, w, 2)).
- 利用空间变换参数Θ对其进行变换操作,生成采样网格(s; x,y∈(-1, 1); (h, w, 2)).
2.3. Differentiable Image Sampling
- 通用的采样公式可写为
- k为通用采样kernel; x, m, y, n为坐标点。Φ为kernel的参数。
- 对于整数采样kernel,公式简化为
- 取x+0.5下界整数,δ函数为Kronecker delta函数
- 对于双线性采样kernel,公式简化为
- 该公式可导
2.4. Spatial Transformer Networks
- 由于Θ显式地编码了变换,因此也可将Θ传入后续的网络,而非变换后的特征图(或图片)。
- 可用STN对特征图进行上采样或下采样。但是,用固定的、小空间支持的采样kernel(双线性kernel)进行下采样会造成影响。
- STN可级联或并行在网络中。
3. Experiments
3.1. Distorted MNIST
- 数据集distorted方式分为
- R 旋转,±90°之间。
- RTS 旋转+缩放+平移
- P 投影
- E 弹性形变(破坏性,不可逆)
- 所有模型都具有相同数量参数,分别使用3类变换操作:仿射变换(Aff)、投影变换(Proj)、薄板样条变换(TPS)。实验发现TPS最有效。
3.2. MNIST Addition
- 输入两张数字图片(h,w,2),输出数字的和。
3.3. Street View House Numbers
- 每张图片有1~5个数字。因此,模型采用级联STN,并使用5个独立的softmax分类器,每个分类器包含一个空字符
3.4. Fine-Grained Classification
- CUB-200-2011数据集,模型采用并行STN结构。
3.5. Co-localization
- 使用半监督学习来定位图像中的物体。基于正确定位对象A与正确定位对象B之间的距离,比A与随机定位crop小的假设,构造hinge loss
- T表示crop,e为编码函数,α为margin,实验设置为1。数据集的构建操作为:将2828的数字图片放在8484背景中,并将从训练集中采样得到的16个随机6*6 crop放入背景中。当预测定位与ground-truth的交集大于0.5时,定义为预测正确。
3.6. Higher Dimensionnal Transformer
模型使用3D仿射变换和3D双线性插值操作。
另一种处理方法是:将3D空间投影到2D空间,例如